In earlier lessons we focused on elementwise operations (e.g., applying a simple ReLU to a matrix). These operations are memory-bound: the GPU spends more time moving data from high-bandwidth memory (HBM) into registers than it spends doing math.
1. Why GEMM Matters
General Matrix Multiplication (GEMM) performs $O(N^3)$ compute against only $O(N^2)$ memory accesses. That gap lets us hide memory latency behind enormous arithmetic throughput, which is why GEMM is the "heart" of large language models (LLMs).
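The $O(N^3)$ vs. $O(N^2)$ gap can be made concrete with a back-of-envelope "arithmetic intensity" calculation (FLOPs per byte of memory traffic). The functions below are illustrative, assuming 4-byte floats and ignoring caching effects:

```python
# Rough arithmetic intensity (FLOPs per byte) for an N x N GEMM
# versus an elementwise ReLU. Illustrative model: 4-byte elements,
# every operand read/written exactly once from HBM.
def gemm_intensity(n, bytes_per_elem=4):
    flops = 2 * n**3                      # N^3 multiply-adds = 2*N^3 FLOPs
    traffic = 3 * n**2 * bytes_per_elem   # read A and B, write C
    return flops / traffic

def relu_intensity(n, bytes_per_elem=4):
    flops = n * n                         # one max(x, 0) per element
    traffic = 2 * n**2 * bytes_per_elem   # read X, write Y
    return flops / traffic

print(gemm_intensity(4096))   # grows linearly with N
print(relu_intensity(4096))   # stays at 0.125, no matter how large N gets
```

GEMM's intensity grows with $N$, so large matrices give the hardware enough math per byte to stay compute-bound; ReLU's intensity is a constant well below what any modern GPU needs to keep its ALUs busy.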
2. Representing 2D in Memory
Physical memory is one-dimensional. To represent a 2D tensor, we use strides. A common production pitfall is assuming the tensor is contiguous: if you mix up the row and column strides in your pointer arithmetic, you will read "ghost data" or trigger memory access violations.
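You can inspect strides directly with NumPy; the same idea applies to PyTorch tensors passed into a Triton kernel. A minimal sketch:

```python
import numpy as np

# A 3x4 float32 matrix: rows are 16 bytes apart, columns 4 bytes apart.
a = np.arange(12, dtype=np.float32).reshape(3, 4)
print(a.strides)                 # (16, 4)

# Transposing swaps the strides without moving any data.
t = a.T
print(t.strides)                 # (4, 16)
print(t.flags['C_CONTIGUOUS'])   # False: row-major pointer math now reads the wrong elements
```

The transpose is a zero-copy view, so any kernel that hard-codes "next row = width elements forward" silently reads the wrong addresses on `t`.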
3. Generalizing to Tiles
Triton generalizes from single pointers to blocks of pointers. By using a 2D tile (e.g., $16 \times 16$), we exploit data reuse to keep the tile "hot" in fast SRAM, which lets us fuse follow-on operations such as a bias add or an activation before writing the result back to global memory.
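The "block of pointers" is built by broadcasting a column of row offsets against a row of column offsets. Here is a NumPy sketch of the same broadcasting trick a Triton kernel performs with `tl.arange`; the strides are assumed row-major values chosen for illustration:

```python
import numpy as np

BLOCK_M, BLOCK_N = 4, 4
stride_m, stride_n = 16, 1          # illustrative row-major strides, in elements

offs_m = np.arange(BLOCK_M)         # analogous to tl.arange(0, BLOCK_M)
offs_n = np.arange(BLOCK_N)         # analogous to tl.arange(0, BLOCK_N)

# Broadcasting a (BLOCK_M, 1) column against a (1, BLOCK_N) row yields the
# full (BLOCK_M, BLOCK_N) grid of element offsets for one tile.
offsets = offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
print(offsets.shape)                # (4, 4)
```

One program instance can then load the entire tile with a single `tl.load(base_ptr + offsets, ...)`-style call instead of iterating over individual pointers.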
QUESTION 1
Why is an elementwise ReLU on a large matrix considered 'memory-bound'?
- The ReLU function requires complex transcendental math.
- The ratio of arithmetic operations to memory loads is very low (1:1). ✅
- Matrices are naturally stored in CPU memory only.
- Triton cannot process non-linear activations.
✅ Correct: Because we perform only one operation per element loaded, the hardware spends most of its time waiting for the bus.
❌ Incorrect: Arithmetic intensity is the ratio of work to memory access. Elementwise ops have very low intensity.

QUESTION 2
What is the result of 'The Stride Trap' in production kernels?
- The kernel runs significantly faster but with less precision.
- Memory access violations or corrupted output due to incorrect address calculation on non-contiguous tensors. ✅
- The GPU automatically corrects the indexing using L2 cache.
- The tensor is forced into a 1D shape by the compiler.
✅ Correct: Assuming contiguity (stride = 1) when a tensor is sliced or transposed leads to reading the wrong memory offsets.
❌ Incorrect: Triton requires explicit stride handling; it won't 'guess' the layout if your math assumes contiguity.

QUESTION 3
How does Triton represent a 2D tile of pointers?
- By using a nested Python list of integers.
- By broadcasting a 1D column vector and a 1D row vector of offsets together. ✅
- By launching multiple 1D kernels sequentially.
- By allocating a special 2D register file.
✅ Correct: `offs_m[:, None] + offs_n[None, :]` creates a 2D coordinate grid used for block loading.
❌ Incorrect: Triton uses broadcasting to efficiently generate multidimensional pointer grids in a single program instance.

QUESTION 4
Which operation benefits most from $O(N^3)$ compute complexity that hides memory latency?
- Vector Addition
- Matrix Multiplication (GEMM) ✅
- Sigmoid Activation
- Global Average Pooling
✅ Correct: GEMM is compute-bound, meaning it does enough math to justify the cost of loading the data tiles.
❌ Incorrect: The other options are $O(N)$ or $O(N^2)$ and typically remain memory-bound.

QUESTION 5
List three kernels in your current workflow that launch multiple PyTorch ops and might benefit from fusion.
- Linear -> Bias -> ReLU; LayerNorm -> Dropout; Softmax -> Masking. ✅
- Print -> Log -> Sleep.
- DataLoader -> Augmentation -> Storage.
- These ops cannot be fused in Triton.
✅ Correct. Reference answer: 1. Linear -> ReLU (common MLP block). 2. LayerNorm -> Dropout (Transformer residual). 3. Softmax -> Masking (attention mechanism). Fusing these avoids intermediate HBM writes.
❌ Incorrect: Look for sequences where a large tensor is modified by simple elementwise or reduction steps.

Case Study: The Contiguity Crisis
Debugging non-contiguous tensor access in production
A developer writes a custom Triton kernel for a linear layer. On standard training data it works perfectly. During inference, however, the input tensor is frequently sliced (e.g., `x[:, :hidden_dim // 2]`), which shrinks the logical shape while the underlying memory layout and strides stay those of the original tensor, so the view is no longer contiguous. The kernel begins outputting NaN and random noise.
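The failure mode is easy to reproduce on the CPU. This NumPy sketch (with an illustrative `hidden_dim = 8`) shows that a column slice keeps the original row stride and therefore stops being contiguous; PyTorch views behave the same way:

```python
import numpy as np

hidden_dim = 8
x = np.arange(4 * hidden_dim, dtype=np.float32).reshape(4, hidden_dim)

view = x[:, : hidden_dim // 2]       # keep only the first half of each row
print(view.shape)                    # (4, 4) - logical shape shrank
print(view.strides)                  # (32, 4) - row stride is still 8 elements
print(view.flags['C_CONTIGUOUS'])    # False
```

A kernel that derives the row stride from the logical width (4) instead of reading the real stride (8 elements) will step through memory incorrectly from the second row onward.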
Q
Why did the kernel fail when the input was sliced?
Solution:
Slicing usually creates a non-contiguous view. If the kernel assumed the row stride was equal to the number of columns (width), but the physical memory jump to the next row remained the original width, the kernel would read 'stale' data from the unsliced portion of memory.
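The "stale data" read can be simulated with a flat buffer. In this sketch, `load_wrong` and `load_right` are hypothetical helpers standing in for the kernel's pointer arithmetic; the buffer backs a 3x4 matrix of which a 3x2 slice is taken:

```python
import numpy as np

buf = np.arange(12, dtype=np.float32)   # flat 1D buffer backing a 3x4 tensor
full = buf.reshape(3, 4)
view = full[:, :2]                      # logical 3x2 slice; physical row stride is still 4

def load_wrong(base, i, j, width=2):
    # BUG: assumes row stride equals the slice's logical width
    return buf[base + i * width + j]

def load_right(base, i, j, stride_m=4, stride_n=1):
    # respects the actual physical stride of the parent tensor
    return buf[base + i * stride_m + j * stride_n]

print(view[1, 0], load_right(0, 1, 0))  # 4.0 4.0 - correct element
print(load_wrong(0, 1, 0))              # 2.0 - 'stale' data from the unsliced columns
```

The wrong load lands inside row 0's discarded columns rather than at the start of row 1, which is exactly the garbage-output symptom described above.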
Q
How should the pointer arithmetic be updated to handle this?
Solution:
The kernel must accept `stride_m` and `stride_n` as arguments. Instead of `ptr = base + i * width + j`, it must use `ptr = base + i * stride_m + j * stride_n` to respect the actual memory mapping.
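The corrected formula can be sanity-checked on the CPU. Here `load_tile` is a hypothetical NumPy helper mirroring the stride-aware tile load a Triton kernel would perform; note that the same function handles both a row-major view and its transpose just by swapping the stride arguments:

```python
import numpy as np

def load_tile(buf, base, offs_m, offs_n, stride_m, stride_n):
    # Stride-aware tile load: correct for any layout, because the
    # address of element (i, j) is base + i*stride_m + j*stride_n.
    idx = base + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
    return buf[idx]

buf = np.arange(12, dtype=np.float32)   # flat buffer backing a 3x4 matrix
offs_m, offs_n = np.arange(2), np.arange(2)

row_major = load_tile(buf, 0, offs_m, offs_n, stride_m=4, stride_n=1)
transposed = load_tile(buf, 0, offs_m, offs_n, stride_m=1, stride_n=4)
print(row_major)    # top-left 2x2 tile of the 3x4 view
print(transposed)   # top-left 2x2 tile of the 4x3 transposed view
```

In practice the host code simply forwards `x.stride(0)` and `x.stride(1)` from the PyTorch tensor into the kernel, so sliced and transposed inputs work without a defensive `.contiguous()` copy.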